/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GEMM_PARTITION_HPP
#define GEMM_PARTITION_HPP

#include <array>
#include <cstdint>
#include <tuple>

#include "nstl.hpp"
#include "utils.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
        dim_t *t_offset, dim_t *t_block) {

    dim_t band = n / nthrs;

    dim_t tail = n - (nthrs - 1) * band;
    if (tail > (band + 1))
        band++;
    tail = n - (nthrs - 1) * band;

    if (ithr < (nthrs - 1))
        *t_block = band;
    else
        *t_block = tail;

    *t_offset = ithr * band;

    if (*t_offset >= n) {
        *t_block = 0;
        *t_offset = 0;
    } else if ((*t_offset + *t_block) > n) {
        *t_block = n - *t_offset;
    }
}

static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
        const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
        const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp,
        dim_t *p_n_band) {

    dim_t m_disp = 0, n_disp = 0;
    dim_t m_band = 0, n_band = 0;

    int mdiv = nthrs_m;
    int ndiv = nthrs_n;

    dim_t m_bandt = m / mdiv; /* size per thread */
    dim_t n_bandt = n / ndiv; /* size per thread */
    int firstmgroup = mdiv - 1;
    int firstngroup = ndiv - 1;
    dim_t firstmval = m_bandt;
    dim_t firstnval = n_bandt;

    int mthr_used = mdiv;
    if (m - (mdiv - 1) * m_bandt > m_bandt + 1) {
        if (m - (mdiv - 1) * m_bandt > mdiv)
            ++m_bandt;

        firstmval = m_bandt + 1;
        mthr_used = (int)(m / firstmval);

        if (mthr_used * firstmval < m)
            ++mthr_used;

        firstmgroup = mthr_used - 1;
    }

    int nthr_used = ndiv;
    if (n - (ndiv - 1) * n_bandt > n_bandt + 1) {
        firstnval = n_bandt + 1;
        nthr_used = (int)(n / firstnval);

        if (nthr_used * firstnval < n)
            ++nthr_used;

        firstngroup = nthr_used - 1;
    }

    *nthrs = mthr_used * nthr_used;

    if (ithr < *nthrs) {
        if (ithr_i < firstmgroup) {
            m_band = firstmval;
            m_disp = ithr_i * firstmval;
        } else if (ithr_i <= mthr_used - 2) {
            m_band = m_bandt;
            m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt;
        } else {
            m_disp = firstmgroup * firstmval
                    + (mthr_used - 1 - firstmgroup) * m_bandt;
            m_band = nstl::max(dim_t(0), m - m_disp);
        }

        if (ithr_j < firstngroup) {
            n_band = firstnval;
            n_disp = ithr_j * firstnval;
        } else if (ithr_j <= nthr_used - 2) {
            n_band = n_bandt;
            n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt;
        } else {
            n_disp = firstngroup * firstnval
                    + (nthr_used - 1 - firstngroup) * n_bandt;
            n_band = nstl::max(dim_t(0), n - n_disp);
        }
        m_disp = nstl::max(nstl::min(m_disp, m - 1), dim_t(0));
        n_disp = nstl::max(nstl::min(n_disp, n - 1), dim_t(0));
    }

    if (ithr < *nthrs) {
        *p_m_disp = m_disp;
        *p_n_disp = n_disp;
        *p_m_band = m_band;
        *p_n_band = n_band;
    } else {
        *p_m_disp = 0;
        *p_n_disp = 0;
        *p_m_band = 0;
        *p_n_band = 0;
    }

    return;
}

// Chooses a 2D thread grid for an m x n problem by distributing the prime
// factors of nthr between the m and n dimensions.
//
// m, n             - problem sizes
// block_m, block_n - band sizes tried first when deciding which dimension
//                    to split
// min_m, min_n     - band sizes used as a fallback (and up front, while
//                    there are fewer block-sized partitions than threads)
// nthr             - number of threads available
//
// Returns (nthr_m, nthr_n). Each step multiplies exactly one of the two by
// a prime factor of nthr, so nthr_m * nthr_n always divides nthr. Only the
// primes up to 29 are considered, and factoring stops early once neither
// dimension can be split further, so the product may be smaller than nthr.
static inline std::tuple<int, int> partition_2d_minblk_with_primes(int m, int n,
        int block_m, int block_n, int min_m, int min_n, int nthr) {

    auto part_m = nstl::max(1, m / block_m);
    auto part_n = nstl::max(1, n / block_n);

    // Quick exit if there are enough partitions in one direction
    // and there is only 1 partition in the other one
    if (part_m == 1 && part_n >= nthr)
        return std::make_tuple(1, nstl::min(part_n, nthr));

    if (part_n == 1 && part_m >= nthr)
        return std::make_tuple(nstl::min(part_m, nthr), 1);

    // Keep the partition count in 64 bits: part_m * part_n (and the later
    // scaling by p) can exceed the range of int for large m and n with
    // small blocks, and signed overflow is undefined behavior.
    std::int64_t num_parts = static_cast<std::int64_t>(part_m) * part_n;

    int nthr_ite = nthr; // part of nthr not yet factored
    int nthr_m = 1, nthr_n = 1;
    int band_m = m, band_n = n; // current per-thread band sizes

    for (auto p : { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 }) {
        bool finished = false;

        while ((nthr_ite % p) == 0 && !finished) {
            nthr_ite /= p;

            // Candidate thread counts and band sizes if the factor p is
            // given to m or to n respectively.
            auto nthr_m_ite = nthr_m * p;
            auto nthr_n_ite = nthr_n * p;

            auto band_m_ite = band_m / p;
            auto band_n_ite = band_n / p;

            // Try partitioning with block size bm x bn.
            // A dimension is eligible when its band, after being split by
            // p, is still at least the given size. Returns true if the
            // factor was assigned to one of the dimensions.
            auto try_partition = [&](int bm, int bn, bool pick_small) {
                float ratio_m = (float)band_m_ite / bm;
                float ratio_n = (float)band_n_ite / bn;
                bool do_m = false, do_n = false;

                if (ratio_m < 1. && ratio_n >= 1.)
                    do_n = true;
                else if (ratio_m >= 1. && ratio_n < 1.)
                    do_m = true;
                else if (ratio_m >= 1. && ratio_n >= 1.) {
                    // Pick either the smaller or larger ratio as appropriate.
                    (((ratio_m < ratio_n) == pick_small) ? do_m : do_n) = true;
                }

                if (do_m) {
                    // Partition m.
                    nthr_m = nthr_m_ite;
                    band_m = band_m_ite;
                } else if (do_n) {
                    // Partition n.
                    nthr_n = nthr_n_ite;
                    band_n = band_n_ite;
                }

                return do_m || do_n;
            };

            // If we will need min based partitioning do it now
            if (num_parts < nthr) {
                num_parts *= p;
                if (try_partition(min_m, min_n, true))
                    continue;
            }

            if (try_partition(block_m, block_n, false))
                continue;
            if (try_partition(min_m, min_n, true))
                continue;

            // Both band_m/n are smaller than min_m/n
            // exit the loops, nothing to partition
            finished = true;
        }

        if (finished)
            break;
    }

    return std::make_tuple(nthr_m, nthr_n);
}

// Chooses a 2D thread grid (nthr_m, nthr_n) for an m x n problem, given
// block sizes block_m x block_n and minimum band sizes min_m x min_n.
//
// If one dimension holds no more than one min-sized band, all threads go to
// the other dimension (capped by the number of min-sized bands it holds,
// rounded up). Otherwise partition_2d_minblk_with_primes() is tried with
// nthr, nthr - 1, ... down to nthr / 2 + 1 threads, stopping as soon as the
// grid found uses at least min(0.95 * nthr, part_m * part_n) threads. If no
// attempt reaches that threshold, the grid from the last attempt is
// returned.
static inline std::tuple<int, int> partition_2d_minblk(int m, int n,
        int block_m, int block_n, int min_m, int min_n, int nthr) {

    int part_m = nstl::max(1, m / min_m);
    int part_n = nstl::max(1, n / min_n);

    // Quick exit if one of the dimensions is too small to partition.
    if (part_m == 1) {
        part_n = nstl::max(1, utils::div_up(n, min_n));
        return std::make_tuple(1, nstl::min(part_n, nthr));
    }

    if (part_n == 1) {
        part_m = nstl::max(1, utils::div_up(m, min_m));
        return std::make_tuple(nstl::min(part_m, nthr), 1);
    }

    // Multiply in double rather than int: part_m * part_n can exceed the
    // range of int for large m and n with small minimum sizes, and signed
    // overflow is undefined behavior.
    const double max_parts = static_cast<double>(part_m) * part_n;

    int nthr_m = 0, nthr_n = 0;
    auto nthr_thresh = nstl::min(0.95 * nthr, max_parts);

    // Thread counts with awkward prime factors may yield a poor grid, so
    // retry with fewer threads until the grid is good enough.
    for (int nthr_new = nthr; nthr_new > nthr / 2; nthr_new--) {
        if (nthr_m * nthr_n >= nthr_thresh)
            break;
        std::tie(nthr_m, nthr_n) = partition_2d_minblk_with_primes(
                m, n, block_m, block_n, min_m, min_n, nthr_new);
    }

    return std::make_tuple(nthr_m, nthr_n);
}

} /* namespace cpu */
} /* namespace impl */
} /* namespace mkldnn */

#endif
